iT邦幫忙

2023 iThome 鐵人賽

DAY 29
0
AI & Data

機器學習不難嘛系列 第 29

Day29-線性回歸 梯度下降

  • 分享至 

  • xImage
  •  

上篇講到了計算最佳預測線的一個方法暴力破解,今天要講的另一個方法叫做梯度下降(Gradient Descent),它的原理是利用每個梯度,也就是斜率來判斷要繼續往前還是要回頭往後走,有學過微積分的人應該會覺得似曾相識。

首先要計算出w和b的梯度,w的梯度公式是將成本函數的公式對x進行一次微分,也就是(y-wx)²對x進行一次微分,可得到結果為-2x(y-wx),因為當中的係數2之後可以消掉,經果整理後可得到結果w的梯度公式為x*(wx-y)。y的梯度公式則是將w的公式再對x進行一次微分,過程就不贅述了,可得結果(wx-y)。

求出這兩個公式後就可以透過輸入w和b兩個參數計算其斜率了,過程如下:

def compute_gradient(x, y, w, b):
  w_gradient = (x*(w*x + b - y)).sum() / len(x)
  b_gradient = ((w*x +b - y)).mean()

  return w_gradient, b_gradient
compute_gradient(x, y, 0, 0)#w和b兩參數值可自行調整

有人可能沒學過微積分不知道斜率代表什麼,斜率就是線上一個點的下一步是往上還是往下走,我們最後要求的一個臨界點的斜率會是0,也就是微分結果會等於0。

介紹完計算公式直接可以開始設定我們的函數了,下面程式碼首先會看到三個變數,其中的learning_rate可以想成是訓練的步伐大小,如果步伐過大可能會走過頭,太小可能會走不到終點,需要靠測試來調整步伐,接下來的自定義函數參數有x、y、w初始值、b初始值(w和b的值會隨著不斷地執行而產生變化)、學習率、之前寫過的成本函數公式、計算梯度公式、執行次數、執行幾次為一組,函數中會建立三個字典來記錄成本、w、b的變化,迴圈用來指定執行次數和計算梯度,if判斷式用來將一組的資料只顯示一個,並整理輸出的數據格式,最後傳回最終的w和b值

w = 0
b = 0
learning_rate = 0.001

def gradient_descent(x, y, w_init, b_init, learning_rate, cost_funtion, gradient_funtion, run_iter, p_iter=1000):

  c_hist = []
  w_hist = []
  b_hist = []

  w = w_init
  b = b_init

  for i in range(run_iter):
    w_gradient, b_gradient = gradient_funtion(x, y, w, b)
    w = w - w_gradient*learning_rate
    b = b - b_gradient*learning_rate
    cost = cost_funtion(x, y, w, b)

    w_hist.append(w)
    b_hist.append(b)
    c_hist.append(cost)

    if i % p_iter == 0:
      print(f"Iteration {i:5} : Cost {cost: .4e}, w {w: .2e}, b {b: .2e}, w_gradient {w_gradient: .2e}, b_gradient {b_gradient: .2e}")

  return w, b, w_hist, b_hist, c_hist

定義完函數並設定好初始參數就可以送進模型計算出結果了,在結果中可以觀察到我們的成本函數在不斷的下降,直到最後幾乎沒有在變動了,就可以得到我們的w_final和b_final了。

w_init = 90
b_init = -100 
learning_rate = 0.001 #w_init、b_init和learning_rate可以自行調整
run_iter = 20000
w_final, b_final, w_hist, b_hist, c_hist = gradient_descent(x, y, w_init, b_init, learning_rate, compute_cost, compute_gradient, run_iter)

https://ithelp.ithome.com.tw/upload/images/20231010/20162311uUvJD1Zb4X.png

最後印出w_final和b_final就可以求到最後的w和b,兩個值分別為95.99,和-99.56,可以發現和暴力破解的結果相當接近。

print(f"最終(w, b) = ({w_final:.2f}, {b_final:.2f})")

上一篇
Day28-線性回歸 暴力破解
下一篇
Day30-結語
系列文
機器學習不難嘛30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言